import os
from time import time
from contextlib import contextmanager

import numpy as np
import taichi as ti
from tqdm import tqdm

from ti_rt.utils.gauss import GaussBloom

from .shape import Shape
from .material import Material
from .camera import Camera
from .color import lRGB2sRGB
from .helper.log import Logger, confirmYes
from .helper.average import MovingAverage
from .scene import Scene
from .utils import iVec2, Vec3, Vec4
from .config import GAMMA

__all__ = [
    "Screen",
]

spsAverage = MovingAverage(100)


@ti.data_oriented
class Screen:
    """Render target bound to a Scene and a Camera.

    Accumulates path-traced samples into a per-pixel buffer, averages and
    tone-maps them to an sRGB image, and (optionally) presents the result
    in a taichi UI window with interactive controls.
    """

    def __init__(
        self,
        scene: Scene,
        camera: Camera,
    ) -> None:
        self.camera = camera
        self.scene = scene
        # (width, height) in pixels, mirrored from the camera.
        self.size = self.w, self.h = self.camera.size
        self.bloom = GaussBloom(self.size)
        # Total number of samples accumulated per pixel so far.
        self.frameCount: ti.Field = ti.field(int, shape=1)
        # Path-tracing bounce limits, editable from the GUI sliders.
        self.minBounce = ti.field(int, shape=1)
        self.maxBounce = ti.field(int, shape=1)
        self.maxBounce[0] = 16
        # Raw accumulated radiance; divide by frameCount[0] to get HDR color.
        self.sampleBuffer: ti.Field = Vec3.field(shape=self.size)
        self.hdrBuffer: ti.Field = Vec3.field(shape=self.size)
        self.image: ti.Field = Vec3.field(shape=self.size)
        self.toBloom = False
        self.renderDepth = False
        self.renderNormal = False
        self.lastCur = 0, 0
        self.exposure = 1.0
        self.gamma = 1 / 2.2
        # Optional per-frame callback installed by `record`.
        self.recorder = None

    @property
    def window(self) -> ti.ui.Window:
        """
        Create the window lazily, only when it is first needed.
        The canvas and GUI properties follow the same pattern.
        """
        if not hasattr(self, "_window"):
            self._window = ti.ui.Window("Camera", res=self.size)
        return self._window

    @property
    def canvas(self) -> ti.ui.Canvas:
        if not hasattr(self, "_canvas"):
            self._canvas = self.window.get_canvas()
        return self._canvas

    @property
    def GUI(self) -> ti.ui.Gui:
        if not hasattr(self, "_GUI"):
            self._GUI = self.window.get_gui()
        return self._GUI

    @property
    def running(self) -> bool:
        """
        Report whether the window is still running while also processing
        keyboard input and drawing the settings sub-window.
        If you only need the running flag, use self.window.running instead.
        """
        clear, self.lastCur = self.camera.keyBoardMsg(self.window, self.lastCur)
        with self.GUI.sub_window("camera settings", 0.7, 0, 0.3, 0.3) as gui:
            clear |= self.camera.makeSettingPlane(gui)
            self.minBounce[0] = gui.slider_int(
                "min bounce", self.minBounce[0], 0, self.maxBounce[0]
            )
            self.maxBounce[0] = gui.slider_int(
                "max bounce", self.maxBounce[0], self.minBounce[0], 32
            )
            self.toBloom = gui.checkbox("bloom", self.toBloom)
            newRadius = gui.slider_int(
                "bloom radius", self.bloom.radius[0], 1, self.bloom.weight.shape[0] - 1
            )
            if newRadius != self.bloom.radius[0]:
                self.bloom.initWeight(newRadius)
            newShift = gui.slider_float(
                "lambda shift",
                self.scene.lambdaShift[0],
                self.scene.sensor.lambdaMin - self.scene.sensor.lambdaMax,
                self.scene.sensor.lambdaMax - self.scene.sensor.lambdaMin,
            )
            if newShift != self.scene.lambdaShift[0]:
                self.scene.lambdaShift[0] = newShift
                # A wavelength shift invalidates every accumulated sample.
                clear = True
            if gui.button("shoot"):
                self.shoot()
        if clear:
            self.clearBuffer()
        if self.recorder is not None:
            self.recorder()
        return self.window.running and not self.window.is_pressed(ti.ui.ESCAPE)

    @ti.kernel
    def sampleN(self, n: int):
        """
        Accumulate n samples per pixel. The rendering mode is selected by
        compile-time (ti.static) flags:
        renderDepth: render a depth map,
        renderNormal: render a normal map,
        otherwise: path-trace with the min/max bounce limits.
        """
        self.frameCount[0] += n
        for u, v in self.sampleBuffer:
            uv = iVec2(u, v)
            ray = self.camera.rayAt(uv)
            if ti.static(self.renderDepth):
                self.sampleBuffer[uv] += self.scene.closest(ray)
            elif ti.static(self.renderNormal):
                self.sampleBuffer[uv] += self.scene.firstNormal(ray)
            else:
                for t in range(n):
                    self.sampleBuffer[uv] += self.scene.integral(
                        self.camera.rayAt(uv, t),
                        self.minBounce[0],
                        self.maxBounce[0],
                    )

    @ti.kernel
    def mapping(self):
        """Average the accumulated samples into HDR, then tone-map to sRGB."""
        for uv in ti.grouped(self.image):
            self.hdrBuffer[uv] = self.sampleBuffer[uv] / self.frameCount[0]
            self.image[uv] = lRGB2sRGB(self.hdrBuffer[uv] * self.exposure)

    def clearBuffer(self):
        """Reset the sample accumulator; returns self for call chaining."""
        self.frameCount[0] = 0
        self.sampleBuffer.fill(0)
        return self

    def render(self, spp: int = 1) -> None:
        """Render one frame at the given samples-per-pixel and present it."""
        self.sampleN(spp)
        self.mapping()
        if self.toBloom:
            self.bloom.produce(self.hdrBuffer, self.image)
        self.canvas.set_image(self.image)
        self.window.show()

    def printInfo(self, spp, dt):
        """
        Print debug info to the GUI overlay:
        spp: samples per pixel per frame,
        frame count: total samples accumulated,
        at: cursor position (normalized),
        buffer: raw accumulator value under the cursor,
        HDR: averaged HDR color under the cursor,
        RGB: display (tone-mapped) color under the cursor.
        """
        sps = spsAverage.next(spp / dt)
        self.GUI.text(f"spp: {spp:.1f}\t sps: {sps:.1f}")
        self.GUI.text(f"frame count: {self.frameCount[0]}")
        x, y = self.window.get_cursor_pos()
        self.GUI.text(f"at: {x:.2f}, {y:.2f}")
        # Clamp so a cursor at the far edge (x == 1.0) stays in bounds.
        x = min(int(x * self.w), self.w - 1)
        y = min(int(y * self.h), self.h - 1)
        rgb = self.sampleBuffer[x, y] * self.exposure
        r, g, b = rgb
        self.GUI.text(f"buffer: {r:.3f}, {g:.3f}, {b:.3f}")
        # Guard against division by zero right after a buffer clear.
        r, g, b = rgb / max(self.frameCount[0], 1)
        self.GUI.text(f"HDR: {r:.3f}, {g:.3f}, {b:.3f}")
        r, g, b = self.image[x, y]
        self.GUI.text(f"RGB: {r:.3f}, {g:.3f}, {b:.3f}")
        x, y, z, t = self.camera.origin[0]
        self.GUI.text(f"x: {x:.3f} y: {y:.3f} z: {z:.3f} t: {t:.3f}")
        self.GUI.text(f"theta: {self.camera.theta:.3f} phi: {self.camera.phi:.3f}")

    def autoRender(self, fps: int, responseMsg: bool = True) -> None:
        """
        Adaptive-frame-rate rendering: keeps the FPS roughly constant by
        adjusting spp each frame, which increases noise when a frame is
        expensive to compute.
        """
        spp, t0 = 1.0, time()
        app = self if responseMsg else self.window
        while app.running:
            self.render(int(spp))
            dt, t0 = time() - t0, time()
            spp = max(1, spp / (dt * fps))
            self.printInfo(spp, dt)
        self.camera.saveSetting()

    def shoot(self, filename: str = None, spp: int = 1024, batch: int = 1) -> None:
        """
        Render a still image and save it to a file.
        A large batch may cause a visible stall between progress updates.

        Raises ValueError when spp is not divisible by batch.
        """
        if spp % batch:
            raise ValueError(f"spp ({spp}) must be divisible by batch ({batch})")
        # Resolve the target path up front so the progress bar can show it.
        filename = f"output/{Logger.now()}.png" if filename is None else filename
        t0 = time()
        for _ in tqdm(range(spp // batch), f"rendering [{filename}]", unit="sample"):
            self.sampleN(batch)
        self.mapping()
        if self.toBloom:
            self.bloom.produce(self.hdrBuffer, self.image)
        Logger.log(f"Completed rendering, cost {time() - t0:.3f} s, spp of {spp}")
        t0 = time()
        # Make sure the output directory exists before writing.
        os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
        ti.tools.imwrite(self.image, filename)
        Logger.log(f"Completed imwrite, cost {time() - t0:.3f} s, shape of {self.size}")

    def staticRender(self, fps: int) -> None:
        """
        Unlike autoRender this does not respond to mouse/keyboard events;
        on close the accumulator is saved to disk so a later run can resume
        rendering where it left off
        (note: it does not check that the camera stayed in place).
        """
        if os.path.exists("temp/sampleBuffer.npy"):
            with open("temp/sampleBuffer.npy", "rb") as f:
                self.frameCount[0] = np.load(f)[0]
                self.sampleBuffer.from_numpy(np.load(f))

        self.autoRender(fps, responseMsg=False)
        # Make sure the checkpoint directory exists before writing.
        os.makedirs("temp", exist_ok=True)
        with open("temp/sampleBuffer.npy", "wb") as f:
            np.save(f, np.array([self.frameCount[0]]))
            np.save(f, self.sampleBuffer.to_numpy())
        if confirmYes("是否保存图片为 *.png 文件?"):
            ti.tools.imwrite(self.image, f"output/{Logger.now()}.png")

    def preActivate(self):
        """Warm up: compile every kernel once; useful to measure compile time."""
        t0 = time()
        self.sampleN(1)
        self.mapping()
        self.clearBuffer()
        Logger.log(f"loaded {Shape.shapesNum[0]} shape")
        Logger.log(f"loaded {Material.materialsNum[0]} material")
        Logger.log(f"compiling take {time() - t0} s")
        return self

    @contextmanager
    def record(self, filepath: str):
        """
        Context manager: while active, camera parameters are appended to
        *filepath* once per frame (via the `running` property).
        """
        with open(filepath, 'w') as f:

            def recorder():
                x, y, z, t = self.camera.origin[0]
                f.write(f"{x}, {y}, {z}, {t}, {self.camera.theta}, {self.camera.phi}\n")

            self.recorder = recorder
            try:
                yield self.recorder
            finally:
                # Always detach, even if the body raises, so later frames
                # do not write to a closed file.
                self.recorder = None

    def renderFromPath(self, filepath: str, spp: int, fps: int) -> None:
        """Render the camera path recorded by `record` into a video."""
        with open(filepath) as f:
            lines = f.readlines()
        video = ti.tools.VideoManager("output/", framerate=fps, automatic_build=False)
        # Remove stale frames left over from a previous run.
        for fn in os.listdir(video.frame_directory):
            if fn.endswith('.png'):
                os.remove(os.path.join(video.frame_directory, fn))
        for line in tqdm(lines):
            x, y, z, t, self.camera.theta, self.camera.phi = map(
                float, line.split(", ")
            )
            self.camera.origin[0] = Vec4(x, y, z, t)
            self.camera.setScreen()
            self.clearBuffer().sampleN(spp)
            self.mapping()
            video.write_frame(self.image)
        video.make_video()